{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0f0e48de",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-19T12:06:52.154221Z",
     "start_time": "2024-02-19T12:06:52.146740Z"
    }
   },
   "outputs": [],
   "source": [
    "# !pip install wandb\n",
    "# !wandb login"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc437c19-a95a-4491-b2a9-9cf59917571e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-17T14:08:33.203436Z",
     "iopub.status.busy": "2024-12-17T14:08:33.202778Z",
     "iopub.status.idle": "2024-12-17T14:08:33.225211Z",
     "shell.execute_reply": "2024-12-17T14:08:33.222790Z",
     "shell.execute_reply.started": "2024-12-17T14:08:33.203386Z"
    }
   },
   "source": [
    "```\n",
    "import wandb\n",
    "\n",
    "NUM_EPOCHS = 100\n",
    "BATCH_SIZE = 1000\n",
    "NUM_TOKENS = 10\n",
    "LR = 1e-5\n",
    "KL_FACTOR = 6000\n",
    "WANDB = False\n",
    "\n",
    "if WANDB:\n",
    "    run = wandb.init(\n",
    "        project=\"tinycatstories\",\n",
    "        config={\n",
    "            \"epochs\": NUM_EPOCHS,\n",
    "            \"batch_size\": BATCH_SIZE,\n",
    "            \"num_tokens\": NUM_TOKENS,\n",
    "            \"learning_rate\": LR,\n",
    "            \"kl_factor\": KL_FACTOR,\n",
    "        },\n",
    "    )\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15111c13",
   "metadata": {},
   "source": [
    "## basics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "21b7e672",
   "metadata": {},
   "source": [
    "```\n",
    "import wandb\n",
    "wandb.init()\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1b4c9871",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-23T14:01:39.919238Z",
     "start_time": "2023-04-23T14:01:38.315799Z"
    }
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "import random # to set the python random seed\n",
    "import numpy # to set the numpy random seed\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from torchvision import models\n",
    "from datetime import datetime\n",
    "# Silence noisy library logging: the root logger only lets ERROR and above through.\n",
    "import logging\n",
    "# NOTE(review): `propagate` is an attribute of Logger objects; assigning it on the\n",
    "# `logging` module itself appears to have no effect -- confirm and drop if so.\n",
    "logging.propagate = False \n",
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "\n",
    "# WandB – Import the wandb library\n",
    "import wandb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdbfacf7",
   "metadata": {},
   "source": [
    "## summary"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97b070f6",
   "metadata": {},
   "source": [
    "```\n",
    "!pip install wandb\n",
    "wandb login\n",
    "# api_key\n",
    "~/.netrc\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d250d6e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-23T12:20:22.753328Z",
     "start_time": "2023-04-23T12:20:22.747238Z"
    }
   },
   "source": [
    "- WandB: weights & biases\n",
    "\n",
    "```\n",
    "wandb.init(project=\"wandb-demo-0423\")\n",
    "# 字典（dict）\n",
    "config = wandb.config\n",
    "config[k] = v\n",
    "\n",
    "# 实例化模型\n",
    "model = Net().to(device)\n",
    "train_dataset\n",
    "test_dataset\n",
    "train_dataloader\n",
    "test_dataloader\n",
    "\n",
    "# 监控模型，histogram weights and biases\n",
    "wandb.watch(model, log=\"all\")\n",
    "\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    train_loss, train_acc = train(model, train_dataloader)\n",
    "    # 字典的形式\n",
    "    wandb.log({\"train_loss\": train_loss, \"train_acc\": train_acc})\n",
    "    # 评估，不进行参数的更新\n",
    "    test_loss, test_acc = test(model, test_dataloader)\n",
    "    wandb.log({\"test_loss\": test_loss, \"test_acc\": test_acc})\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bda83126",
   "metadata": {},
   "source": [
    "## model, train & test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4d1c2f12",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-23T14:01:46.208111Z",
     "start_time": "2023-04-23T14:01:46.203295Z"
    }
   },
   "outputs": [],
   "source": [
    "def train(train_dataloader, model, criterion, optimizer, device):\n",
    "    total_loss = 0\n",
    "    total_correct = 0\n",
    "    total_batch = len(train_dataloader)\n",
    "    for batch_idx, (images, labels) in enumerate(train_dataloader):\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # forward\n",
    "        out = model(images)\n",
    "        loss = criterion(out, labels)\n",
    "\n",
    "        # 标准的处理，用 validate data；这个过程是监督训练过程，用于 early stop\n",
    "        n_corrects = (out.argmax(axis=1) == labels).sum().item()\n",
    "        acc = n_corrects/labels.size(0)\n",
    "\n",
    "        # backward\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()   # 更细 模型参数\n",
    "        \n",
    "        total_loss += loss.item()\n",
    "        total_correct += n_corrects\n",
    "        \n",
    "        if (batch_idx+1) % 200 == 0:\n",
    "            print(f'{datetime.now()}, {batch_idx+1}/{total_batch}: {loss.item():.4f}, acc: {acc}')\n",
    "    total_errors = len(train_dataloader.dataset) - total_correct\n",
    "    return total_loss, total_correct/len(train_dataloader.dataset), total_errors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b7f0cdfb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-23T14:02:03.797083Z",
     "start_time": "2023-04-23T14:02:03.784871Z"
    }
   },
   "outputs": [],
   "source": [
    "def test(test_dataloader, model, criterion, device, classes):\n",
    "    total_loss = 0\n",
    "    total_correct = 0\n",
    "    example_images = []\n",
    "    model.eval()\n",
    "    for images, labels in test_dataloader:\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "        out = model(images)\n",
    "        loss = criterion(out, labels)\n",
    "        total_loss += loss.item()\n",
    "        preds = torch.argmax(out, dim=1)\n",
    "        total_correct += (preds == labels).sum().item()\n",
    "        \n",
    "        mis_preds_indice = torch.flatten((preds != labels).nonzero())\n",
    "        mis_preds = preds[mis_preds_indice]\n",
    "        mis_labels = labels[mis_preds_indice]\n",
    "        mis_images = images[mis_preds_indice]\n",
    "        \n",
    "        # 13*8 + 4 == 108\n",
    "        for index in range(len(mis_preds)):\n",
    "            example_images.append(wandb.Image(mis_images[index], \n",
    "                                              caption=\"Pred: {} Truth: {}\".format(classes[mis_preds[index].item()],\n",
    "                                                                                  classes[mis_labels[index]])))\n",
    "    total_errors = len(test_loader.dataset) - total_correct\n",
    "    return example_images, total_loss, total_correct / len(test_loader.dataset), total_errors\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86e4bf47",
   "metadata": {},
   "source": [
    "## wandb config & dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "20354ff6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-23T14:02:15.849162Z",
     "start_time": "2023-04-23T14:02:05.682841Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mlanchunhui\u001b[0m (\u001b[33mloveresearch\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.15.0"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/home/whaow/workspaces/llm_aigc/tutorials/nn_basics/wandb/run-20230423_220207-dggdir0s</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/loveresearch/wandb-demo-0423/runs/dggdir0s' target=\"_blank\">ethereal-silence-10</a></strong> to <a href='https://wandb.ai/loveresearch/wandb-demo-0423' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/loveresearch/wandb-demo-0423' target=\"_blank\">https://wandb.ai/loveresearch/wandb-demo-0423</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/loveresearch/wandb-demo-0423/runs/dggdir0s' target=\"_blank\">https://wandb.ai/loveresearch/wandb-demo-0423/runs/dggdir0s</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "# Optional: supply the API key through the environment instead of `wandb login`.\n",
    "# Keep the real key out of the notebook.\n",
    "# os.environ[\"WANDB_API_KEY\"] = ''\n",
    "os.environ[\"WANDB_MODE\"] = \"online\"\n",
    "\n",
    "# WandB – Initialize a new run\n",
    "# A single project can hold many runs.\n",
    "wandb.init(project=\"wandb-demo-0423\")\n",
    "wandb.watch_called = False # Re-run the model without restarting the runtime, unnecessary after our next release"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d130c023",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-23T14:02:17.680506Z",
     "start_time": "2023-04-23T14:02:17.673655Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# No hyperparameters have been set yet, so the config displays as an empty dict.\n",
    "wandb.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5ac35385",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-23T14:02:18.655123Z",
     "start_time": "2023-04-23T14:02:18.641752Z"
    }
   },
   "outputs": [],
   "source": [
    "# WandB – `wandb.config` holds the hyperparameters and inputs saved with the run\n",
    "config = wandb.config          # handle to the current run's config\n",
    "config.batch_size = 64          # batch size of the training DataLoader\n",
    "config.test_batch_size = 32    # batch size of the test DataLoader\n",
    "config.epochs = 30             # number of training epochs\n",
    "config.lr = 1e-3              # SGD learning rate\n",
    "config.momentum = 0.9         # SGD momentum\n",
    "config.weight_decay = 5e-4    # SGD weight decay\n",
    "config.no_cuda = False         # set True to force CPU even when CUDA is available\n",
    "config.seed = 42               # seed passed to torch.manual_seed\n",
    "config.log_interval = 10     # NOTE(review): not read in the visible cells; train() prints every 200 batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "30148831",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-23T15:10:30.432096Z",
     "start_time": "2023-04-23T15:10:30.423172Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "{'num_workers': 1, 'pin_memory': True}\n"
     ]
    }
   ],
   "source": [
    "# Use the GPU only when it is both allowed by the config and actually present.\n",
    "use_cuda = (not config.no_cuda) and torch.cuda.is_available()\n",
    "if use_cuda:\n",
    "    device = torch.device(\"cuda\")\n",
    "    # Extra DataLoader options that are only applied for CUDA runs.\n",
    "    kwargs = {'num_workers': 1, 'pin_memory': True}\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "    kwargs = {}\n",
    "print(device)\n",
    "print(kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3c289c3d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-23T14:02:24.799152Z",
     "start_time": "2023-04-23T14:02:22.592645Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
      "  warnings.warn(\n",
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n",
      "  warnings.warn(msg)\n"
     ]
    }
   ],
   "source": [
    "# Untrained ResNet-18; `weights=None` replaces the deprecated `pretrained=False`.\n",
    "model = models.resnet18(weights=None)\n",
    "# Replace the final fully connected layer with a 10-way classifier (CIFAR-10 has 10 classes).\n",
    "in_features = model.fc.in_features\n",
    "model.fc = nn.Linear(in_features, 10)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "bba03e00",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-23T14:02:26.754527Z",
     "start_time": "2023-04-23T14:02:24.801556Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Preprocessing shared by both splits: resize to 224x224, convert to a tensor,\n",
    "# then normalise each RGB channel with the given mean / std.\n",
    "transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize(size=(224, 224)),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(\n",
    "            mean=(0.4914, 0.4822, 0.4465),\n",
    "            std=(0.2023, 0.1994, 0.2010),\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# CIFAR-10 train / test splits, stored under ./data (download=True).\n",
    "train_dataset = datasets.CIFAR10(\n",
    "    root='./data', train=True, download=True, transform=transform\n",
    ")\n",
    "test_dataset = datasets.CIFAR10(\n",
    "    root='./data', train=False, download=True, transform=transform\n",
    ")\n",
    "\n",
    "# Only the training loader shuffles; `kwargs` carries the device-dependent options.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=config.batch_size, shuffle=True, **kwargs\n",
    ")\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_dataset, batch_size=config.test_batch_size, shuffle=False, **kwargs\n",
    ")\n",
    "\n",
    "# Display names indexed by class id (used for the misclassification captions).\n",
    "classes = (\n",
    "    'plane', 'car', 'bird', 'cat', 'deer',\n",
    "    'dog', 'frog', 'horse', 'ship', 'truck',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "635db4d6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-23T15:11:21.290317Z",
     "start_time": "2023-04-23T15:11:21.283595Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "50000\n",
      "781\n",
      "782\n"
     ]
    }
   ],
   "source": [
    "# Sample count, number of full batches, and number of batches the loader yields.\n",
    "print(\n",
    "    len(train_dataset),\n",
    "    len(train_dataset) // config.batch_size,\n",
    "    len(train_loader),\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2c2e26f6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-23T15:11:50.081828Z",
     "start_time": "2023-04-23T15:11:50.069920Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000\n",
      "312\n",
      "313\n"
     ]
    }
   ],
   "source": [
    "# Sample count, number of full batches, and number of batches the loader yields.\n",
    "print(\n",
    "    len(test_dataset),\n",
    "    len(test_dataset) // config.test_batch_size,\n",
    "    len(test_loader),\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5aba878c",
   "metadata": {},
   "source": [
    "## training pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "99a696a2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-04-23T14:52:50.016509Z",
     "start_time": "2023-04-23T14:02:30.904713Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-04-23 22:02:47.546538, 200/782: 1.7578, acc: 0.3125\n",
      "2023-04-23 22:03:01.684265, 400/782: 1.5881, acc: 0.40625\n",
      "2023-04-23 22:03:15.786362, 600/782: 1.7024, acc: 0.359375\n",
      "\n",
      "2023-04-23 22:04:48.248923, epoch: 1, train_loss: 1312.3292, train_acc: 0.38, test_loss: 451.0153, test_acc: 0.47\n",
      "\n",
      "2023-04-23 22:05:03.120458, 200/782: 1.3696, acc: 0.578125\n",
      "2023-04-23 22:05:17.786350, 400/782: 1.7198, acc: 0.390625\n",
      "2023-04-23 22:05:31.756241, 600/782: 1.3919, acc: 0.46875\n",
      "\n",
      "2023-04-23 22:06:59.056604, epoch: 2, train_loss: 1111.5958, train_acc: 0.48, test_loss: 430.7115, test_acc: 0.50\n",
      "\n",
      "2023-04-23 22:07:14.636778, 200/782: 1.3315, acc: 0.453125\n",
      "2023-04-23 22:07:29.614013, 400/782: 1.1030, acc: 0.515625\n",
      "2023-04-23 22:07:43.630946, 600/782: 0.9127, acc: 0.5625\n",
      "\n",
      "2023-04-23 22:09:02.521576, epoch: 3, train_loss: 951.4182, train_acc: 0.56, test_loss: 373.4743, test_acc: 0.56\n",
      "\n",
      "2023-04-23 22:09:18.509211, 200/782: 1.1813, acc: 0.5625\n",
      "2023-04-23 22:09:32.542236, 400/782: 0.8315, acc: 0.734375\n",
      "2023-04-23 22:09:46.439042, 600/782: 0.9831, acc: 0.609375\n",
      "\n",
      "2023-04-23 22:10:56.286045, epoch: 4, train_loss: 848.2849, train_acc: 0.61, test_loss: 328.4867, test_acc: 0.62\n",
      "\n",
      "2023-04-23 22:11:11.895453, 200/782: 0.7470, acc: 0.765625\n",
      "2023-04-23 22:11:26.310563, 400/782: 0.8621, acc: 0.671875\n",
      "2023-04-23 22:11:40.374984, 600/782: 1.0024, acc: 0.625\n",
      "\n",
      "2023-04-23 22:12:48.117540, epoch: 5, train_loss: 763.4396, train_acc: 0.65, test_loss: 313.2866, test_acc: 0.64\n",
      "\n",
      "2023-04-23 22:13:02.603392, 200/782: 1.0213, acc: 0.640625\n",
      "2023-04-23 22:13:16.836407, 400/782: 0.9512, acc: 0.6875\n",
      "2023-04-23 22:13:31.229744, 600/782: 1.0002, acc: 0.609375\n",
      "\n",
      "2023-04-23 22:14:33.092854, epoch: 6, train_loss: 687.2216, train_acc: 0.69, test_loss: 273.0934, test_acc: 0.69\n",
      "\n",
      "2023-04-23 22:14:48.393379, 200/782: 0.8387, acc: 0.609375\n",
      "2023-04-23 22:15:02.561705, 400/782: 0.6446, acc: 0.734375\n",
      "2023-04-23 22:15:16.724601, 600/782: 0.8553, acc: 0.71875\n",
      "\n",
      "2023-04-23 22:16:20.240459, epoch: 7, train_loss: 631.4524, train_acc: 0.71, test_loss: 290.2783, test_acc: 0.68\n",
      "\n",
      "2023-04-23 22:16:34.831223, 200/782: 0.8965, acc: 0.6875\n",
      "2023-04-23 22:16:49.138514, 400/782: 0.6861, acc: 0.796875\n",
      "2023-04-23 22:17:03.360651, 600/782: 0.6669, acc: 0.796875\n",
      "\n",
      "2023-04-23 22:18:03.156384, epoch: 8, train_loss: 585.1868, train_acc: 0.73, test_loss: 258.8160, test_acc: 0.70\n",
      "\n",
      "2023-04-23 22:18:18.378913, 200/782: 0.8935, acc: 0.59375\n",
      "2023-04-23 22:18:32.730347, 400/782: 0.8117, acc: 0.71875\n",
      "2023-04-23 22:18:46.914735, 600/782: 0.5463, acc: 0.8125\n",
      "\n",
      "2023-04-23 22:19:44.052278, epoch: 9, train_loss: 533.3294, train_acc: 0.76, test_loss: 250.0995, test_acc: 0.72\n",
      "\n",
      "2023-04-23 22:19:58.823197, 200/782: 0.5807, acc: 0.765625\n",
      "2023-04-23 22:20:13.196541, 400/782: 0.9108, acc: 0.671875\n",
      "2023-04-23 22:20:27.714339, 600/782: 0.6214, acc: 0.765625\n",
      "\n",
      "2023-04-23 22:21:23.335821, epoch: 10, train_loss: 485.8444, train_acc: 0.78, test_loss: 241.6396, test_acc: 0.73\n",
      "\n",
      "2023-04-23 22:21:38.151306, 200/782: 0.4001, acc: 0.859375\n",
      "2023-04-23 22:21:53.485954, 400/782: 0.6767, acc: 0.796875\n",
      "2023-04-23 22:22:07.825476, 600/782: 0.5942, acc: 0.796875\n",
      "\n",
      "2023-04-23 22:23:01.450106, epoch: 11, train_loss: 441.0487, train_acc: 0.80, test_loss: 241.8771, test_acc: 0.74\n",
      "\n",
      "2023-04-23 22:23:16.125828, 200/782: 0.4643, acc: 0.859375\n",
      "2023-04-23 22:23:30.410816, 400/782: 0.6353, acc: 0.796875\n",
      "2023-04-23 22:23:44.746983, 600/782: 0.5426, acc: 0.890625\n",
      "\n",
      "2023-04-23 22:24:35.179366, epoch: 12, train_loss: 399.8553, train_acc: 0.82, test_loss: 217.5537, test_acc: 0.76\n",
      "\n",
      "2023-04-23 22:24:50.443131, 200/782: 0.2437, acc: 0.921875\n",
      "2023-04-23 22:25:06.149641, 400/782: 0.6050, acc: 0.859375\n",
      "2023-04-23 22:25:21.603247, 600/782: 0.3656, acc: 0.859375\n",
      "\n",
      "2023-04-23 22:26:13.198636, epoch: 13, train_loss: 362.7531, train_acc: 0.84, test_loss: 218.6354, test_acc: 0.76\n",
      "\n",
      "2023-04-23 22:26:28.201559, 200/782: 0.5647, acc: 0.765625\n",
      "2023-04-23 22:26:42.483938, 400/782: 0.2675, acc: 0.890625\n",
      "2023-04-23 22:26:56.691393, 600/782: 0.3848, acc: 0.84375\n",
      "\n",
      "2023-04-23 22:27:46.352323, epoch: 14, train_loss: 313.0561, train_acc: 0.86, test_loss: 217.1808, test_acc: 0.77\n",
      "\n",
      "2023-04-23 22:28:00.524663, 200/782: 0.3543, acc: 0.890625\n",
      "2023-04-23 22:28:14.408919, 400/782: 0.4251, acc: 0.859375\n",
      "2023-04-23 22:28:28.382975, 600/782: 0.2761, acc: 0.890625\n",
      "\n",
      "2023-04-23 22:29:17.655147, epoch: 15, train_loss: 271.4094, train_acc: 0.88, test_loss: 224.5323, test_acc: 0.76\n",
      "\n",
      "2023-04-23 22:29:33.925060, 200/782: 0.3610, acc: 0.875\n",
      "2023-04-23 22:29:48.659994, 400/782: 0.4308, acc: 0.8125\n",
      "2023-04-23 22:30:03.099810, 600/782: 0.1750, acc: 0.953125\n",
      "\n",
      "2023-04-23 22:30:51.616468, epoch: 16, train_loss: 236.6358, train_acc: 0.89, test_loss: 230.9388, test_acc: 0.78\n",
      "\n",
      "2023-04-23 22:31:07.603442, 200/782: 0.2267, acc: 0.90625\n",
      "2023-04-23 22:31:22.711451, 400/782: 0.3092, acc: 0.875\n",
      "2023-04-23 22:31:37.241017, 600/782: 0.2333, acc: 0.921875\n",
      "\n",
      "2023-04-23 22:32:25.691087, epoch: 17, train_loss: 196.9677, train_acc: 0.91, test_loss: 217.6643, test_acc: 0.79\n",
      "\n",
      "2023-04-23 22:32:41.135319, 200/782: 0.2882, acc: 0.859375\n",
      "2023-04-23 22:32:55.722146, 400/782: 0.2844, acc: 0.90625\n",
      "2023-04-23 22:33:09.953132, 600/782: 0.2135, acc: 0.953125\n",
      "\n",
      "2023-04-23 22:34:00.979964, epoch: 18, train_loss: 168.3277, train_acc: 0.92, test_loss: 264.2805, test_acc: 0.75\n",
      "\n",
      "2023-04-23 22:34:16.908534, 200/782: 0.1116, acc: 0.96875\n",
      "2023-04-23 22:34:31.807264, 400/782: 0.1439, acc: 0.921875\n",
      "2023-04-23 22:34:45.712084, 600/782: 0.1036, acc: 0.96875\n",
      "\n",
      "2023-04-23 22:35:38.108862, epoch: 19, train_loss: 135.9203, train_acc: 0.94, test_loss: 296.8457, test_acc: 0.75\n",
      "\n",
      "2023-04-23 22:35:54.153294, 200/782: 0.1364, acc: 0.921875\n",
      "2023-04-23 22:36:09.262755, 400/782: 0.0671, acc: 0.96875\n",
      "2023-04-23 22:36:23.537375, 600/782: 0.1974, acc: 0.953125\n",
      "\n",
      "2023-04-23 22:37:12.569849, epoch: 20, train_loss: 107.5051, train_acc: 0.95, test_loss: 275.2776, test_acc: 0.77\n",
      "\n",
      "2023-04-23 22:37:27.284441, 200/782: 0.0218, acc: 1.0\n",
      "2023-04-23 22:37:41.694630, 400/782: 0.1110, acc: 0.953125\n",
      "2023-04-23 22:37:56.006253, 600/782: 0.1225, acc: 0.96875\n",
      "\n",
      "2023-04-23 22:38:49.425602, epoch: 21, train_loss: 90.0286, train_acc: 0.96, test_loss: 299.9435, test_acc: 0.75\n",
      "\n",
      "2023-04-23 22:39:04.244424, 200/782: 0.0326, acc: 0.984375\n",
      "2023-04-23 22:39:18.322099, 400/782: 0.0232, acc: 1.0\n",
      "2023-04-23 22:39:32.395894, 600/782: 0.0951, acc: 0.96875\n",
      "\n",
      "2023-04-23 22:40:29.115772, epoch: 22, train_loss: 74.0507, train_acc: 0.97, test_loss: 401.6254, test_acc: 0.73\n",
      "\n",
      "2023-04-23 22:40:43.651838, 200/782: 0.0158, acc: 1.0\n",
      "2023-04-23 22:40:57.697383, 400/782: 0.1317, acc: 0.953125\n",
      "2023-04-23 22:41:11.690586, 600/782: 0.0784, acc: 0.984375\n",
      "\n",
      "2023-04-23 22:42:01.388435, epoch: 23, train_loss: 60.5226, train_acc: 0.97, test_loss: 345.7840, test_acc: 0.77\n",
      "\n",
      "2023-04-23 22:42:15.953622, 200/782: 0.0163, acc: 1.0\n",
      "2023-04-23 22:42:31.494169, 400/782: 0.0103, acc: 1.0\n",
      "2023-04-23 22:42:45.669067, 600/782: 0.0384, acc: 1.0\n",
      "\n",
      "2023-04-23 22:43:31.808899, epoch: 24, train_loss: 57.3829, train_acc: 0.98, test_loss: 279.6026, test_acc: 0.79\n",
      "\n",
      "2023-04-23 22:43:45.992129, 200/782: 0.0765, acc: 0.953125\n",
      "2023-04-23 22:44:01.124313, 400/782: 0.0120, acc: 1.0\n",
      "2023-04-23 22:44:15.299129, 600/782: 0.0298, acc: 1.0\n",
      "\n",
      "2023-04-23 22:45:03.497198, epoch: 25, train_loss: 37.1311, train_acc: 0.99, test_loss: 318.1354, test_acc: 0.78\n",
      "\n",
      "2023-04-23 22:45:18.069714, 200/782: 0.0336, acc: 0.984375\n",
      "2023-04-23 22:45:33.038247, 400/782: 0.0101, acc: 1.0\n",
      "2023-04-23 22:45:50.076586, 600/782: 0.1896, acc: 0.9375\n",
      "\n",
      "2023-04-23 22:46:38.573868, epoch: 26, train_loss: 38.6355, train_acc: 0.98, test_loss: 308.9865, test_acc: 0.78\n",
      "\n",
      "2023-04-23 22:46:53.009193, 200/782: 0.0106, acc: 1.0\n",
      "2023-04-23 22:47:07.215939, 400/782: 0.0210, acc: 1.0\n",
      "2023-04-23 22:47:21.751417, 600/782: 0.0140, acc: 0.984375\n",
      "\n",
      "2023-04-23 22:48:16.886609, epoch: 27, train_loss: 41.3992, train_acc: 0.98, test_loss: 372.1271, test_acc: 0.73\n",
      "\n",
      "2023-04-23 22:48:32.395892, 200/782: 0.0029, acc: 1.0\n",
      "2023-04-23 22:48:47.715308, 400/782: 0.0040, acc: 1.0\n",
      "2023-04-23 22:49:01.742648, 600/782: 0.0362, acc: 0.984375\n",
      "\n",
      "2023-04-23 22:49:49.587063, epoch: 28, train_loss: 35.3954, train_acc: 0.99, test_loss: 289.5657, test_acc: 0.78\n",
      "\n",
      "2023-04-23 22:50:05.257999, 200/782: 0.0078, acc: 1.0\n",
      "2023-04-23 22:50:19.670978, 400/782: 0.0026, acc: 1.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023-04-23 22:50:33.631395, 600/782: 0.0246, acc: 0.984375\n",
      "\n",
      "2023-04-23 22:51:20.359464, epoch: 29, train_loss: 22.9649, train_acc: 0.99, test_loss: 302.0734, test_acc: 0.79\n",
      "\n",
      "2023-04-23 22:51:35.176612, 200/782: 0.0075, acc: 1.0\n",
      "2023-04-23 22:51:49.426564, 400/782: 0.0012, acc: 1.0\n",
      "2023-04-23 22:52:03.474307, 600/782: 0.0045, acc: 1.0\n",
      "\n",
      "2023-04-23 22:52:50.013599, epoch: 30, train_loss: 6.8951, train_acc: 1.00, test_loss: 355.9989, test_acc: 0.80\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Seed every RNG in use and request deterministic cuDNN kernels for reproducibility.\n",
    "# NOTE(review): the model is built in an earlier cell, so its weight initialisation\n",
    "# is not covered by these seeds -- move the seeding up if that matters.\n",
    "random.seed(config.seed)        # python random seed\n",
    "numpy.random.seed(config.seed)  # numpy random seed\n",
    "torch.manual_seed(config.seed)  # pytorch random seed\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "# SGD with momentum and weight decay; every hyperparameter comes from wandb.config.\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# WandB – wandb.watch() hooks the model so its gradients are logged to the dashboard.\n",
    "# Using log=\"all\" log histograms of parameter values in addition to gradients\n",
    "wandb.watch(model, log=\"all\")\n",
    "\n",
    "for epoch in range(1, config.epochs + 1):\n",
    "    train_loss, train_acc, train_errors = train(train_loader, model, criterion, optimizer, device)\n",
    "    example_images, test_loss, test_acc, test_errors = test(test_loader, model, criterion, device, classes)\n",
    "    # One wandb.log call per epoch: each call advances the step counter, so logging\n",
    "    # train and test metrics separately would place them on different steps.\n",
    "    wandb.log({\n",
    "        'train_loss': train_loss,\n",
    "        'train_acc': train_acc,\n",
    "        'train_errors': train_errors,\n",
    "        'example_images': example_images,\n",
    "        'test_loss': test_loss,\n",
    "        'test_acc': test_acc,\n",
    "        'test_errors': test_errors,\n",
    "    })\n",
    "    print()\n",
    "    print(f'{datetime.now()}, epoch: {epoch}, train_loss: {train_loss:.4f}, train_acc: {train_acc:.2f}, test_loss: {test_loss:.4f}, test_acc: {test_acc:.2f}')\n",
    "    print()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "333f99ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# WandB – Save the model checkpoint: write the state dict to model.ckpt locally,\n",
    "# then pass that file to wandb.save() so it is uploaded and associated with the current run.\n",
    "torch.save(model.state_dict(), \"model.ckpt\")\n",
    "wandb.save('model.ckpt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
